import os
import numpy as np
import pandas as pd
import tkinter as tk
from tkinter import filedialog, simpledialog
import matplotlib.pyplot as plt
import mne
import gc

# --- Show ICA sources and topomaps sequentially ---
def plot_ica_combined(ica, raw):
    """Show the ICA source time courses first, then the component topomaps.

    ``plot_sources`` is called with ``block=True`` so MNE halts until the
    source browser is closed; the topomap figure is opened afterwards.
    """
    # NOTE(review): MNE's ``sphere`` is (x, y, z, radius); 0.09 is presumably a
    # 9 cm head radius in meters — confirm it suits this montage.
    topomap_sphere = (0., 0., 0., 0.09)

    print("Displaying ICA time series (sources)...")
    ica.plot_sources(raw, show=True, block=True, show_scrollbars=True)

    print("Displaying ICA topographic components...")
    ica.plot_components(sphere=topomap_sphere, show=True)

# --- Select folder ---
# A hidden Tk root is needed by filedialog/simpledialog; it is kept alive for
# the whole run because simpledialog is used again for every recording.
root = tk.Tk()
root.withdraw()
input_folder = filedialog.askdirectory(title="Select Folder Containing EEG and Event Files")
if not input_folder:
    print("No folder selected. Exiting.")
    # `exit()` is injected by the `site` module and is not guaranteed to exist
    # (e.g. `python -S`, frozen apps); SystemExit is always available.
    raise SystemExit

# --- Match files ---
# Regular files only, sorted so pairing and processing order are deterministic
# (os.listdir returns entries in arbitrary order).
all_files = sorted(
    f for f in os.listdir(input_folder)
    if os.path.isfile(os.path.join(input_folder, f))
)
# Role keywords and the extension are both matched case-insensitively.
eeg_files = [f for f in all_files if 'raw' in f.lower() and f.lower().endswith('.txt')]
event_files = [f for f in all_files if 'event' in f.lower() and f.lower().endswith('.txt')]

def normalize(name):
    """Reduce an EEG/event file name to a bare key used to pair the two files.

    The name is lower-cased and separators (space, underscore, hyphen) are
    removed first. Then the ".txt" extension and the role tokens "raw",
    "eventprobe", "event" and "nofeedback" are removed. For example,
    "S01_raw.txt", "S01_Event_Probe.txt" and "S01 No Feedback raw.txt" all
    reduce to "s01".

    Args:
        name: A file name (not a path).

    Returns:
        The normalized key; may be empty if the name held only role tokens.
    """
    name = name.lower()
    # Strip separators before matching tokens so multi-word tokens are found
    # however they are written ("no_feedback", "No Feedback", "event-probe").
    for sep in (' ', '_', '-'):
        name = name.replace(sep, '')
    # "eventprobe" must precede its prefix "event", or "probe" is left behind.
    for token in ('.txt', 'raw', 'eventprobe', 'event', 'nofeedback'):
        name = name.replace(token, '')
    return name.strip()

def _find_event_file(eeg_base, event_keys):
    """Return the event file paired with normalized EEG key *eeg_base*, or None.

    *event_keys* is a list of ``(event_file, normalized_key)`` tuples. An
    exact key match is preferred. Otherwise a substring match in either
    direction is accepted only when there is exactly one candidate, because
    a bare substring test would pair "s1" with "s10", "s11", ... and silently
    attach the wrong events. Ambiguous cases are reported and yield None.
    """
    if not eeg_base:
        return None  # an empty key is a substring of every name
    exact = [f for f, key in event_keys if key == eeg_base]
    candidates = exact or [
        f for f, key in event_keys
        if key and (eeg_base in key or key in eeg_base)
    ]
    if len(candidates) == 1:
        return candidates[0]
    if candidates:
        print(f"[WARNING] Ambiguous event files for base '{eeg_base}': {candidates}")
    return None


# Normalize each event file name once instead of once per EEG file.
event_keys = [(f, normalize(f)) for f in event_files]

paired_files = []
for eeg_file in eeg_files:
    eeg_base = normalize(eeg_file)
    event_file = _find_event_file(eeg_base, event_keys)
    if event_file is None:
        print(f"[WARNING] No event file matched for {eeg_file} (base: {eeg_base})")
        continue
    print(f"[MATCH] {eeg_file} <--> {event_file} [base: {eeg_base} ~ {normalize(event_file)}]")
    paired_files.append((eeg_file, event_file))

if not paired_files:
    print("No valid EEG/Event pairs found.")
    # `exit()` comes from the `site` module and may be absent; SystemExit is not.
    raise SystemExit

# --- Constants ---
sfreq = 256  # sampling rate (Hz) given to mne.create_info for every recording
electrodes = "POz Fz Cz C3 C4 F3 F4 P3 P4".split()  # EEG file column order
expected_len = 205  # samples kept per epoch (window around each event)
valid_types = {"error", "correct"}  # event labels that are epoched

# --- Per-file processing helpers ---
def _load_raw(eeg_path):
    """Load a headerless tab-separated EEG file into an MNE RawArray.

    Columns are assigned to ``electrodes`` in order (a column-count mismatch
    raises). The data are sampled at ``sfreq`` and the standard 10-20 montage
    is attached.
    """
    eeg_df = pd.read_csv(eeg_path, sep='\t', header=None)
    eeg_df.columns = electrodes
    # NOTE(review): MNE interprets EEG samples as volts; if these files hold
    # microvolts, the saved .fif amplitudes are scaled by 1e6 — confirm units.
    info = mne.create_info(ch_names=electrodes, sfreq=sfreq, ch_types='eeg')
    raw = mne.io.RawArray(eeg_df.values.T, info)
    montage = mne.channels.make_standard_montage("standard_1020")
    raw.set_montage(montage, match_case=False, on_missing='ignore')
    return raw


def _fit_ica(raw):
    """Fit ICA on a 1 Hz high-passed copy of *raw*; return (ica, raw_for_ica)."""
    raw_for_ica = raw.copy().filter(1.0, None)
    ica = mne.preprocessing.ICA(n_components=min(len(electrodes), 20), random_state=42, max_iter=800)
    ica.fit(raw_for_ica)
    return ica, raw_for_ica


def _suggest_eog_components(ica, raw_for_ica):
    """Return components flagged by find_bads_eog with Fz as the EOG proxy ([] on failure)."""
    print("Identifying ocular ICA components using Fz...")
    try:
        eog_inds, _ = ica.find_bads_eog(raw_for_ica, ch_name='Fz', threshold=2.0)
    except Exception as e:  # best effort only: the user still chooses manually
        print("Could not auto-detect ocular components:", e)
        return []
    print("Suggested components to remove:", eog_inds)
    return eog_inds


def _parse_components(text, n_components):
    """Parse a comma-separated list of ICA component indices.

    Empty or cancelled (None) input yields []. Blank tokens (e.g. from a
    trailing comma) are ignored and duplicates dropped. Returns None when any
    token is not an integer in ``range(n_components)``.
    """
    selected = []
    for token in (text or '').split(','):
        token = token.strip()
        if not token:
            continue
        try:
            index = int(token)
        except ValueError:
            return None
        if not 0 <= index < n_components:
            return None
        if index not in selected:
            selected.append(index)
    return selected


def _load_events(event_path):
    """Read the tab-separated event table, keeping rows whose type is in valid_types."""
    event_df = pd.read_csv(event_path, sep='\t')
    event_df.columns = event_df.columns.str.strip().str.lower()
    # Compare labels case-insensitively, like the lower-cased headers.
    event_df['type'] = event_df['type'].astype(str).str.strip().str.lower()
    return event_df[event_df['type'].isin(valid_types)]


def _extract_epochs(raw_data, event_df):
    """Cut fixed-length epochs around each event and reject high-amplitude ones.

    Each epoch is exactly ``expected_len`` samples. It starts ``n_pre`` samples
    (200 ms rounded to whole samples) before the event sample, matching MNE's
    rounding of ``tmin=-0.2`` so the event falls at t=0. An epoch is dropped
    if it runs off either end of the recording, if its latency is missing, or
    if any channel's peak-to-peak range exceeds 150.
    """
    n_pre = int(round(0.2 * sfreq))
    n_samples = raw_data.shape[1]
    accepted = []
    for _, row in event_df.iterrows():
        if pd.isna(row['latency']):
            continue
        latency_ms = int(row['latency'])  # NOTE(review): latency assumed to be in ms
        event_sample = int(round(latency_ms * sfreq / 1000))
        # Derive the end from the start instead of truncating two ms offsets
        # separately. 800 ms is 204.8 samples, so separate truncation yields
        # 204- or 205-sample windows and silently drops valid epochs.
        start_idx = event_sample - n_pre
        end_idx = start_idx + expected_len
        if start_idx < 0 or end_idx > n_samples:
            continue
        epoch = raw_data[:, start_idx:end_idx]
        # NOTE(review): 150 is in the data's own units (150 µV only if the
        # input files are in microvolts) — confirm.
        if np.any(np.ptp(epoch, axis=1) > 150):
            continue
        accepted.append({'type': row['type'], 'start_ms': latency_ms,
                         'sample': event_sample, 'data': epoch})
    return accepted


def _save_epochs(epochs, info, base_name):
    """Save accepted epochs as an MNE -epo.fif and a .npy array; return both paths."""
    present = {ep['type'] for ep in epochs}
    # Codes stay fixed across files (sorted valid_types -> 1, 2, ...), but only
    # types that occur are passed: EpochsArray raises by default
    # (on_missing='raise') for event_id entries with no matching events.
    event_id = {t: idx + 1 for idx, t in enumerate(sorted(valid_types)) if t in present}
    events = np.array([[ep['sample'], 0, event_id[ep['type']]] for ep in epochs])
    epochs_array = np.array([ep['data'] for ep in epochs])
    epochs_mne = mne.EpochsArray(epochs_array, info=info, events=events, event_id=event_id, tmin=-0.2, baseline=(-0.2, 0))

    fif_path = os.path.join(input_folder, f"{base_name}_correct_error_cleaned-epo.fif")
    npy_path = os.path.join(input_folder, f"{base_name}_correct_error_cleaned_epochs.npy")
    epochs_mne.save(fif_path, overwrite=True)
    np.save(npy_path, epochs_array)
    return fif_path, npy_path


# --- Process each file pair ---
for eeg_file, event_file in paired_files:
    print("\nProcessing", eeg_file, "and", event_file)
    base_name = eeg_file.replace('_raw.txt', '').replace(' raw.txt', '').replace('Raw.txt', '').replace('.txt', '')

    raw = _load_raw(os.path.join(input_folder, eeg_file))

    # --- Filtering ---
    raw.filter(0.2, 40, fir_design='firwin')
    # NOTE(review): 60 Hz already lies above the 40 Hz low-pass edge, so this
    # notch is probably redundant; kept to preserve the existing pipeline.
    raw.notch_filter(60, fir_design='firwin')

    # --- ICA ---
    ica, raw_for_ica = _fit_ica(raw)
    plot_ica_combined(ica, raw_for_ica)
    eog_inds = _suggest_eog_components(ica, raw_for_ica)

    suggestion = ",".join(str(i) for i in eog_inds)
    prompt = f"{base_name}: Suggested components to remove are: {suggestion}\nEnter components to exclude (comma-separated):"
    comp_input = simpledialog.askstring("ICA Component Confirmation", prompt)
    selected = _parse_components(comp_input, ica.n_components_)
    if selected is None:
        print("Invalid input. No components excluded.")
        selected = []
    else:
        print("Components selected for exclusion:", selected)
    ica.exclude = selected

    raw_clean = ica.apply(raw.copy())

    # --- Events, epoching and amplitude rejection ---
    event_df = _load_events(os.path.join(input_folder, event_file))
    epochs = _extract_epochs(raw_clean.get_data(), event_df)
    if not epochs:
        print(base_name, ": No accepted epochs after rejection.")
        continue

    # --- Save ---
    fif_path, npy_path = _save_epochs(epochs, raw_clean.info, base_name)
    print("Saved cleaned epochs to:", fif_path, "and", npy_path)

    # --- Cleanup: release per-file data before loading the next recording ---
    del raw, raw_for_ica, raw_clean, ica, event_df, epochs
    gc.collect()

print("\nBatch preprocessing completed.")
